{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inspect Datasets and Save as Smol Objects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "from pathlib import Path\n",
    "from rdkit import Chem, RDLogger\n",
    "from torchmetrics import MetricCollection\n",
    "from rdkit.Chem.Draw import IPythonConsole\n",
    "IPythonConsole.ipython_3d = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import semlaflow.util.rdkit as smolRD\n",
    "import semlaflow.util.functional as smolF\n",
    "import semlaflow.util.metrics as Metrics\n",
    "from semlaflow.util.tokeniser import Vocabulary\n",
    "from semlaflow.util.molrepr import GeometricMol, GeometricMolBatch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "QM9_PATH = \"../../../data/qm9\"\n",
    "RAW_DIR =\"raw\"\n",
    "SPLIT_DIR = \"raw_split\"\n",
    "SAVE_DIR = \"smol\"\n",
    "SDF_FILE = \"gdb9.sdf\"\n",
    "METADATA_FILE = \"gdb9.sdf.csv\"\n",
    "SKIP_FILE = \"uncharacterized.txt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copied from MiDi code, so should create the same splits (they didn't make them available)\n",
    "def split_qm9(metadata_df):\n",
    "    n_samples = len(metadata_df)\n",
    "    n_train = 100000\n",
    "    n_test = int(0.1 * n_samples)\n",
    "    n_val = n_samples - (n_train + n_test)\n",
    "\n",
    "    # Shuffle dataset with df.sample, then split\n",
    "    train, val, test = np.split(metadata_df.sample(frac=1, random_state=42), [n_train, n_val + n_train])\n",
    "    return train, val, test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Will skip mol indices which appear in the skip file\n",
    "def rdkit_mols_from_df(split_path, sdf_path, skip_path):\n",
    "    target_df = pd.read_csv(split_path, index_col=0)\n",
    "    target_df.drop(columns=['mol_id'], inplace=True)\n",
    "\n",
    "    with open(skip_path, 'r') as f:\n",
    "        skip = [int(x.split()[0]) - 1 for x in f.read().split('\\n')[9:-2]]\n",
    "\n",
    "    suppl = Chem.SDMolSupplier(str(sdf_path), removeHs=False, sanitize=False)\n",
    "\n",
    "    mols = []\n",
    "    all_smiles = []\n",
    "\n",
    "    errors = 0\n",
    "    skipped = 0\n",
    "\n",
    "    for i, mol in enumerate(tqdm(suppl)):\n",
    "        if i not in target_df.index:\n",
    "            continue\n",
    "\n",
    "        if i in skip:\n",
    "            skipped += 1\n",
    "            continue\n",
    "\n",
    "        try:\n",
    "            Chem.SanitizeMol(mol)\n",
    "            smiles = Chem.MolToSmiles(mol, isomericSmiles=False)\n",
    "        except:\n",
    "            smiles = None\n",
    "\n",
    "        if smiles is None:\n",
    "            errors += 1\n",
    "        else:\n",
    "            all_smiles.append(smiles)\n",
    "            mols.append(mol)\n",
    "\n",
    "    print(f\"Skipped {skipped} mols which where in skip file.\")\n",
    "    print(f\"Encountered {errors} molecules which failed sanitisation.\")\n",
    "    print(f\"Completed loading of dataset with {len(mols)} molecules.\")\n",
    "\n",
    "    return mols"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_vocab():\n",
    "    # Need to make sure PAD has index 0\n",
    "    special_tokens = [\"<PAD>\", \"<MASK>\"]\n",
    "    core_atoms = [\"H\", \"C\", \"N\", \"O\", \"F\", \"P\", \"S\", \"Cl\"]\n",
    "    other_atoms = [\"Br\", \"B\", \"Al\", \"Si\", \"As\", \"I\", \"Hg\", \"Bi\"]\n",
    "    tokens = special_tokens + core_atoms + other_atoms\n",
    "    return Vocabulary(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def matching_smiles(rdkit_mol, smol_mol, vocab):\n",
    "    rdkit_mol2 = smol_mol.to_rdkit(vocab)\n",
    "    smi1 = smolRD.smiles_from_mol(rdkit_mol, canonical=True)\n",
    "    smi2 = smolRD.smiles_from_mol(rdkit_mol2, canonical=True)\n",
    "    return smi1 == smi2"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## QM9"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Split QM9 and load into separate CSVs\n",
    "\n",
    "I have copied the code from the MiDi paper and used the same random seed, so hopefully this will generate the same splits as they used. But they haven't provided their splits so we can't say for sure without these.\n",
    "\n",
    "This code just splits the csv file, which contains metadata and properties for each molecule. The full molecular coordinates are stored in a single sdf file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "qm9_path = Path(QM9_PATH)\n",
    "dataset = pd.read_csv(qm9_path / RAW_DIR / METADATA_FILE)\n",
    "train, val, test = split_qm9(dataset)\n",
    "\n",
    "train_csv_path = qm9_path / SPLIT_DIR / \"train.csv\"\n",
    "val_csv_path = qm9_path / SPLIT_DIR / \"val.csv\"\n",
    "test_csv_path = qm9_path / SPLIT_DIR / \"test.csv\"\n",
    "\n",
    "# train.to_csv(train_csv_path)\n",
    "# val.to_csv(val_csv_path)\n",
    "# test.to_csv(test_csv_path)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Smol Datasets from RDKit Mols from SDF Files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "RDLogger.DisableLog('rdApp.*')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = build_vocab()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sdf_path = qm9_path / RAW_DIR / SDF_FILE\n",
    "skip_path = qm9_path / RAW_DIR / SKIP_FILE\n",
    "\n",
    "print(\"Processing train data...\")\n",
    "train_mols = rdkit_mols_from_df(train_csv_path, sdf_path, skip_path)\n",
    "\n",
    "print(\"Processing val data...\")\n",
    "val_mols = rdkit_mols_from_df(val_csv_path, sdf_path, skip_path)\n",
    "\n",
    "print(\"Processing test data...\")\n",
    "test_mols = rdkit_mols_from_df(test_csv_path, sdf_path, skip_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create Smol batches for ease of use later on\n",
    "train_batch = GeometricMolBatch([GeometricMol.from_rdkit(mol) for mol in train_mols])\n",
    "val_batch = GeometricMolBatch([GeometricMol.from_rdkit(mol) for mol in val_mols])\n",
    "test_batch = GeometricMolBatch([GeometricMol.from_rdkit(mol) for mol in test_mols])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check it looks right\n",
    "print(\"Dataset sizes:\")\n",
    "print(len(train_batch))\n",
    "print(len(val_batch))\n",
    "print(len(test_batch))\n",
    "\n",
    "example_mol = train_batch[567]\n",
    "print()\n",
    "print(\"Example mol:\")\n",
    "print(example_mol.coords)\n",
    "print(example_mol.atomics)\n",
    "print(example_mol.bonds)\n",
    "print(example_mol.charges)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "example_mol.to_rdkit(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for atom in example_mol.to_rdkit(vocab).GetAtoms():\n",
    "    print(f\"Atom {atom.GetSymbol()} -- charge {atom.GetFormalCharge()} -- valence {atom.GetExplicitValence()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_path = qm9_path / SAVE_DIR / \"train.smol\"\n",
    "val_path = qm9_path / SAVE_DIR / \"val.smol\"\n",
    "test_path = qm9_path / SAVE_DIR / \"test.smol\"\n",
    "\n",
    "train_bytes = train_batch.to_bytes()\n",
    "val_bytes = val_batch.to_bytes()\n",
    "test_bytes = test_batch.to_bytes()\n",
    "\n",
    "train_path.write_bytes(train_bytes)\n",
    "val_path.write_bytes(val_bytes)\n",
    "test_path.write_bytes(test_bytes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_matching = [matching_smiles(mol1, mol2, vocab) for mol1, mol2 in zip(train_mols, train_batch.to_list())]\n",
    "print(\"Proportion matching\", sum(train_matching) / len(train_matching))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(train_mols))\n",
    "print(len(train_batch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "unmatched_idxs = [idx for idx, matching in enumerate(train_matching) if not matching]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 100\n",
    "unmatched_idx = unmatched_idxs[idx]\n",
    "print(smolRD.smiles_from_mol(train_mols[unmatched_idx]))\n",
    "print(smolRD.smiles_from_mol(train_batch[unmatched_idx].to_rdkit(vocab)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_valid = [smolRD.mol_is_valid(mol.to_rdkit(vocab)) for mol in train_batch]\n",
    "print(\"Propertion valid\", sum(train_valid) / len(train_valid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analyse QM9 Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_coords = train_batch.coords\n",
    "train_mask = train_batch.mask\n",
    "\n",
    "_, std_dev = smolF.standardise_coords(train_coords, train_mask)\n",
    "print(\"Coord std dev on train data\", std_dev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_n_atoms = sum(train_batch.seq_length) / len(train_batch.seq_length)\n",
    "max_n_atoms = max(train_batch.seq_length)\n",
    "min_n_atoms = min(train_batch.seq_length)\n",
    "print(\"avg\", avg_n_atoms)\n",
    "print(\"max\", max_n_atoms)\n",
    "print(\"min\", min_n_atoms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(train_batch.seq_length, bins=26)\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Firstly, try loading the saved data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SAVE_DIR = \"smol\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "qm9_path = Path(QM9_PATH)\n",
    "train_path = qm9_path / SAVE_DIR / \"train.smol\"\n",
    "val_path = qm9_path / SAVE_DIR / \"val.smol\"\n",
    "test_path = qm9_path / SAVE_DIR / \"test.smol\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_bytes = train_path.read_bytes()\n",
    "val_bytes = val_path.read_bytes()\n",
    "test_bytes = test_path.read_bytes()\n",
    "\n",
    "train_batch = GeometricMolBatch.from_bytes(train_bytes)\n",
    "val_batch = GeometricMolBatch.from_bytes(val_bytes)\n",
    "test_batch = GeometricMolBatch.from_bytes(test_bytes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = build_vocab()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_mols = train_batch.to_list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_mols[567].to_rdkit(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for atom in sample_mols[567].to_rdkit(vocab).GetAtoms():\n",
    "    print(f\"Atom {atom.GetSymbol()} -- charge {atom.GetFormalCharge()} -- valence {atom.GetExplicitValence()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_metrics = {\n",
    "    \"validity\": Metrics.Validity(),\n",
    "    \"fc-validity\": Metrics.Validity(connected=True),\n",
    "    \"uniqueness\": Metrics.Uniqueness(),\n",
    "    \"energy-validity\": Metrics.EnergyValidity(),\n",
    "    \"opt-energy-validity\": Metrics.EnergyValidity(optimise=True),\n",
    "    \"energy\": Metrics.AverageEnergy(),\n",
    "    \"energy-per-atom\": Metrics.AverageEnergy(per_atom=True),\n",
    "    \"strain\": Metrics.AverageStrainEnergy(),\n",
    "    \"strain-per-atom\": Metrics.AverageStrainEnergy(per_atom=True),\n",
    "    \"opt-rmsd\": Metrics.AverageOptRmsd()\n",
    "}\n",
    "gen_metrics = MetricCollection(gen_metrics, compute_groups=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute benchmark metrics on loaded train dataset samples\n",
    "rdkit_sample_mols = [mol.to_rdkit(vocab, sanitize=True) for mol in sample_mols]\n",
    "gen_metrics.reset()\n",
    "gen_metrics.update(rdkit_sample_mols)\n",
    "results = gen_metrics.compute()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for metric, result in results.items():\n",
    "    print(f\"{metric} -- {result.item():.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute benchmark metrics on original train dataset samples\n",
    "gen_metrics.reset()\n",
    "gen_metrics.update(train_mols)\n",
    "results = gen_metrics.compute()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for metric, result in results.items():\n",
    "    print(f\"{metric} -- {result.item():.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for idx, mol in enumerate(sample_mols[82008:82010]):\n",
    "    print(idx)\n",
    "    mol.to_rdkit(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 82000 + 8\n",
    "sample_mols[idx].to_rdkit(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "original_mol = Chem.Mol(train_mols[idx])\n",
    "Chem.SanitizeMol(original_mol)\n",
    "original_mol"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Recreate this issue with functions in the notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mol_from_atoms(coords, tokens, bonds):\n",
    "    try:\n",
    "        atomics = [smolRD.PT.atomic_from_symbol(token) for token in tokens]\n",
    "    except:\n",
    "        return None\n",
    "\n",
    "    # Add atom types\n",
    "    mol = Chem.EditableMol(Chem.Mol())\n",
    "    for atomic in atomics:\n",
    "        mol.AddAtom(Chem.Atom(atomic))\n",
    "\n",
    "    # Add 3D coords\n",
    "    conf = Chem.Conformer(coords.shape[0])\n",
    "    for idx, coord in enumerate(coords.tolist()):\n",
    "        conf.SetAtomPosition(idx, coord)\n",
    "\n",
    "    mol = mol.GetMol()\n",
    "    mol.AddConformer(conf)\n",
    "\n",
    "    # Add bonds if they have been provided\n",
    "    mol = Chem.EditableMol(mol)\n",
    "    for bond in bonds.astype(np.int32).tolist():\n",
    "        start, end, b_type = bond\n",
    "\n",
    "        if b_type not in smolRD.IDX_BOND_MAP:\n",
    "            return None\n",
    "\n",
    "        # Don't add self connections\n",
    "        if start != end:\n",
    "            b_type = smolRD.IDX_BOND_MAP[b_type]\n",
    "            mol.AddBond(start, end, b_type)\n",
    "\n",
    "    mol = mol.GetMol()\n",
    "    for atom in mol.GetAtoms():\n",
    "        atom.UpdatePropertyCache(strict=False)\n",
    "\n",
    "    # try:\n",
    "    #     Chem.SanitizeMol(mol)\n",
    "    # except:\n",
    "    #     return None\n",
    "\n",
    "    return mol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_rdkit(mol, vocab):\n",
    "    if len(mol.atomics.size()) == 2:\n",
    "        vocab_indices = torch.argmax(mol.atomics, dim=1).tolist()\n",
    "        tokens = vocab.tokens_from_indices(vocab_indices)\n",
    "\n",
    "    else:\n",
    "        atomics = mol.atomics.tolist()\n",
    "        tokens = [smolRD.PT.symbol_from_atomic(a) for a in atomics]\n",
    "\n",
    "    coords = mol.coords.numpy()\n",
    "    bonds = mol.bonds.numpy()\n",
    "\n",
    "    rdkit_mol = mol_from_atoms(coords, tokens, bonds)\n",
    "    return rdkit_mol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 82000 + 8\n",
    "problem_mol = sample_mols[idx]\n",
    "rdkit_mol = to_rdkit(problem_mol, vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for atom in rdkit_mol.GetAtoms():\n",
    "    print(f\"Atom {atom.GetSymbol()} -- charge {atom.GetFormalCharge()} -- valence {atom.GetExplicitValence()}\")\n",
    "\n",
    "print()\n",
    "for atom in original_mol.GetAtoms():\n",
    "    print(f\"Atom {atom.GetSymbol()} -- charge {atom.GetFormalCharge()} -- valence {atom.GetExplicitValence()}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fegnn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
